import os
# Expose only GPU 0 to this process. Set before transformers/torch are
# imported so the restriction is in place before CUDA is initialized.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from peft import PeftConfig, PeftModel
import json
import concurrent.futures
import re
from tqdm import tqdm
import json  # NOTE(review): duplicate of the `import json` above; harmless.
# Use the (single visible) GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Directory of the LoRA adapter checkpoint. PeftConfig reads the adapter
# config from here, which records the base model name/path used below.
model_path = '/data/home/zhanghx/code/DataContaminate/ckpts/model/newllama-7b/seed_1/complete_0.77_new'

config = PeftConfig.from_pretrained(model_path)
# Load the base model named in the adapter config.
# NOTE(review): no torch_dtype is passed, so weights load in the installed
# transformers version's default dtype; for a 7B model an explicit
# half-precision dtype may be wanted to reduce GPU memory — confirm first.
model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path)
# Wrap the base model with the adapter weights stored in model_path.
lora_model = PeftModel.from_pretrained(model, model_path)
# The tokenizer is taken from the base model, not from the adapter directory.
tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
# From here on the global `model` is the adapter-wrapped model; the
# prediction function below reads this global.
model = lora_model

model.to(device)


def predict_next_part_with_llama(input_text, last_text):
    """Greedily continue ``input_text`` with the module-level ``model``.

    Args:
        input_text: Full prompt string (instruction template plus the first
            half of a sample).
        last_text: Ground-truth continuation for this prompt. Not used for
            generation; accepted because ``parallel_predict`` passes it.

    Returns:
        The decoded output sequence *including* the prompt, with special
        tokens removed. Callers locate the completion after the
        ``### Response:`` marker.
    """
    # Tokenize via __call__ so we also get the attention mask; passing it to
    # generate() avoids the per-call "attention mask not set" inference and
    # warning. input_ids are identical to tokenizer.encode(...).
    encoded = tokenizer(input_text, return_tensors='pt').to(device)
    input_ids = encoded['input_ids']
    # Allow as many new tokens as the prompt itself contains.
    max_new_tokens = input_ids.shape[1]
    outputs = model.generate(
        input_ids,
        attention_mask=encoded['attention_mask'],
        max_new_tokens=max_new_tokens,
        num_return_sequences=1,
        do_sample=False,
    )
    # outputs[0] holds prompt + generated tokens for the single sequence.
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

def form_data(data_path):
    """Load and return the parsed JSON document stored at ``data_path``.

    The file is decoded as UTF-8 explicitly, so the result does not depend on
    the platform's locale encoding.

    Args:
        data_path: Path to a JSON file. The rest of this script expects a
            list of records with ``'text'`` and ``'label'`` keys.

    Returns:
        The deserialized JSON value.
    """
    with open(data_path, 'r', encoding='utf-8') as f:
        return json.load(f)

def generate(dataset):
    """Split each sample's text in half and wrap the first half in a prompt.

    Args:
        dataset: Sequence of records, each holding a ``'text'`` string and a
            ``'label'`` value.

    Returns:
        A 4-tuple ``(texts, last_text_total, labels, first_text_total)``:
        the prompts, the held-back second halves, the labels, and the first
        halves, all aligned by index.
    """
    texts, last_text_total, labels, first_text_total = [], [], [], []
    for sample in dataset:
        words = sample['text'].split()
        cut = len(words) // 2
        input_text = " ".join(words[:cut])
        response = " ".join(words[cut:])

        # NOTE(review): the template below (doubled "an", leading spaces on
        # each line) is kept verbatim; it presumably mirrors the fine-tuning
        # prompt, so changing it changes what the model sees — confirm first.
        text = f'''Below is an  an incomplete input, Write a response that appropriately completes the input.
        
        ### Input:
        {input_text}
        
        ### Response:
        '''
        labels.append(sample['label'])
        texts.append(text)
        last_text_total.append(response)
        first_text_total.append(input_text)
    return texts, last_text_total, labels, first_text_total
        
    
def parallel_predict(input_texts, last_text, num_workers=4):
    """Run ``predict_next_part_with_llama`` over all prompts with a thread pool.

    Args:
        input_texts: Prompts to complete.
        last_text: Matching ground-truth continuations, passed through to the
            predictor. Pairs are formed with ``zip``, so surplus items in the
            longer sequence are ignored.
        num_workers: Number of worker threads.

    Returns:
        Decoded generations in the SAME order as ``input_texts``. Callers
        align these by index with labels and source texts, so input order
        must be preserved even though tasks finish out of order.
    """
    with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
        # Remember each future's input position: as_completed() yields in
        # completion order, which would misalign results with their labels.
        index_of = {
            executor.submit(predict_next_part_with_llama, text, last): idx
            for idx, (text, last) in enumerate(zip(input_texts, last_text))
        }
        results = [None] * len(index_of)
        for future in tqdm(concurrent.futures.as_completed(index_of), total=len(index_of)):
            results[index_of[future]] = future.result()
    return results
      
def write_to_json(first_text_total, last_text_total, complete_text, labels, path):
    """Write first halves, second halves, completions and labels to JSON.

    Each completion is truncated to at most as many whitespace-separated
    words as its corresponding first half. Completions that already have
    fewer words are kept verbatim. ``None`` completions (no response could be
    extracted) are written as JSON ``null``. The input lists are not
    modified. The parent directory of ``path`` is created if missing.

    Args:
        first_text_total: First halves of the samples.
        last_text_total: Ground-truth second halves.
        complete_text: Model completions (strings or ``None``).
        labels: Per-sample labels.
        path: Output JSON file path.

    Raises:
        AssertionError: if the four lists differ in length. Raised explicitly
            rather than via ``assert`` so the check survives ``python -O``.
    """
    if not (len(first_text_total) == len(last_text_total) == len(complete_text) == len(labels)):
        raise AssertionError("error: length not equal")

    truncated = []
    for first, completion in zip(first_text_total, complete_text):
        if completion is None:
            # Nothing to truncate; keep the missing marker as JSON null.
            truncated.append(None)
            continue
        limit = len(first.split())
        words = completion.split()
        truncated.append(' '.join(words[:limit]) if len(words) >= limit else completion)

    dic = {"first_text": first_text_total, "last_text": last_text_total, "complete_text": truncated, "labels": labels}
    directory = os.path.dirname(path)
    if directory:
        os.makedirs(directory, exist_ok=True)
    with open(path, 'w', encoding='utf-8') as f:
        json.dump(dic, f, ensure_ascii=False, indent=4)

def extract_response(generated_text):
    """Pull the model's answer out of each decoded generation.

    For every string, find the first ``### Response:`` followed by a newline,
    skip any whitespace after it, and keep the rest of that line, stripped.

    Args:
        generated_text: Iterable of decoded generations (prompt included).

    Returns:
        A list with one entry per input: the extracted answer string, or
        ``None`` when the marker is not present.
    """
    pattern = re.compile(r'### Response:\n\s*(.*)')
    extracted = []
    for generation in generated_text:
        found = pattern.search(generation)
        extracted.append(found.group(1).strip() if found else None)
    return extracted

if __name__ == '__main__':
    # Build prompts from the test split.
    dataset = form_data('benchmarks/fine_tuning/test_data.json')
    texts, last_text_total, labels, first_text_total = generate(dataset)

    # Generate completions and keep only the text after "### Response:".
    generations = parallel_predict(texts, last_text_total)
    complete_text = extract_response(generations)

    # Save the halves, completions and labels side by side in one JSON file.
    write_to_json(
        first_text_total,
        last_text_total,
        complete_text,
        labels,
        "complete_data/test_complete_text.json",
    )
    
    
    